idx-ubyte 文件格式
idx-ubyte 是一种很简单的二进制文件格式,著名的 MNIST 使用的就是该格式。
它由一个 magic-number 和各个维度的长度组成 header,然后是主体数据。magic-number 和维度的长度都是 32 位大端无符号整数。
- idx1-ubyte 的数据有一个维度,magic-number 的值为 0x00000801
- idx3-ubyte 的数据有三个维度,magic-number 的值为 0x00000803
// Illustrative layout of an idx1-ubyte file (one dimension).
// Both header fields are stored in the file as 32-bit big-endian unsigned
// integers, so on a little-endian host they must be byte-swapped after
// reading; the struct cannot simply be overlaid on the raw file bytes.
struct Idx1Ubyte
{
    uint32_t magicNumber; // 0x00000801 for idx1-ubyte
    uint32_t dim1;        // length of the only dimension
    uint8_t datas[];      // payload bytes; flexible array member (a compiler extension in C++)
};
7
// Illustrative layout of an idx3-ubyte file (three dimensions).
// All four header fields are stored in the file as 32-bit big-endian
// unsigned integers; the payload bytes follow immediately after the header.
struct Idx3Ubyte
{
    uint32_t magicNumber; // 0x00000803 for idx3-ubyte
    uint32_t dim1;        // length of dimension 1
    uint32_t dim2;        // length of dimension 2
    uint32_t dim3;        // length of dimension 3
    uint8_t datas[];      // payload bytes; flexible array member (a compiler extension in C++)
};

// Taking MNIST as an example:
- train-images.idx3-ubyte 是训练集图片
- 维度 1 的值是 60000,表示包含 60000 张图片
- 维度 2 的值是 28,表示一张图片有 28 行像素
- 维度 3 的值是 28,表示一张图片有 28 列像素
- train-labels.idx1-ubyte 是训练集标注
- 维度 1 的值是 60000,表示包含 60000 个标注
1#ifndef IDX_UBYTE_HPP
2#define IDX_UBYTE_HPP
3
#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <vector>
9
// One item of an idx-ubyte dataset: an N-dimensional block of bytes.
//
// `data` is an owning pointer to a buffer allocated with `new uint8_t[]`;
// its size in bytes is the product of the first N entries of `dims`
// (1 when N == 0). The destructor releases it with `delete[]`, so code that
// assigns `data` directly hands ownership of that buffer to this object.
//
// Allocation failure inside the noexcept copy operations terminates the
// program, because the exception cannot leave a noexcept function.
template<uint8_t N>
struct IdxUbyteData
{
    // Owned buffer, or nullptr when the item is empty.
    uint8_t* data = nullptr;

    // Length of each dimension, zero-initialized. The array is given at least
    // one element because a zero-length array is not valid standard C++;
    // only the first N entries are meaningful.
    uint32_t dims[N == 0 ? 1 : N] = {};

    IdxUbyteData() noexcept = default;

    ~IdxUbyteData() noexcept
    {
        delete[] data; // deleting nullptr is a no-op
    }

    // Move: steal the buffer and leave `src` empty so it will not free it.
    IdxUbyteData(IdxUbyteData&& src) noexcept
        : data(src.data)
    {
        src.data = nullptr;
        memcpy(dims, src.dims, sizeof(dims));
    }

    // Copy: duplicate the dimensions and, if present, the buffer contents.
    IdxUbyteData(const IdxUbyteData& src) noexcept
    {
        memcpy(dims, src.dims, sizeof(dims));

        // An empty source has nothing to duplicate; reading through a null
        // pointer here would be undefined behaviour.
        if (src.data == nullptr)
        {
            return;
        }

        size_t bytes = 1;
        for (size_t i = 0; i < N; i++)
        {
            bytes *= dims[i];
        }

        data = new uint8_t[bytes];
        memcpy(data, src.data, bytes);
    }

    // Move-assign: release the current buffer, then take over `src`'s.
    IdxUbyteData& operator=(IdxUbyteData&& src) noexcept
    {
        if (this != &src)
        {
            delete[] data;
            data = src.data;
            src.data = nullptr;
            memcpy(dims, src.dims, sizeof(dims));
        }
        return *this;
    }

    // Copy-assign: build a full copy first, then move it in, so `*this` is
    // only modified once the copy exists.
    IdxUbyteData& operator=(const IdxUbyteData& src) noexcept
    {
        if (this != &src)
        {
            IdxUbyteData copy(src);
            *this = static_cast<IdxUbyteData&&>(copy);
        }
        return *this;
    }
};
48
49template<uint8_t N>
50class IdxUbyte
51{
52public:
53 IdxUbyte() noexcept = default;
54 ~IdxUbyte() noexcept = default;
55
56 bool write(const char* file, const std::vector< IdxUbyteData<N-1> >& dataset) const noexcept
57 {
58 if (dataset.size() == 0)
59 return false;
60
61 FILE* fp = fopen(file, "wb");
62 if (fp == nullptr)
63 {
64 fprintf(stderr, "%s\n", strerror(errno));
65 return false;
66 }
67
68 this->m_write<32>(fp, MagicNumber);
69 this->m_write<32>(fp, dataset.size());
70
71 size_t bytes = 1;
72 for (uint32_t i = 0; i < N-1; i++)
73 {
74 this->m_write<32>(fp, dataset[0].dims[i]);
75 bytes *= dataset[0].dims[i];
76 }
77
78 for (const auto& data : dataset)
79 {
80 if (fwrite(data.data, 1, bytes, fp) < bytes)
81 {
82 fprintf(stderr, "%s\n", strerror(errno));
83 }
84 }
85
86 fclose(fp);
87 return true;
88 }
89
90 std::vector< IdxUbyteData<N-1> > read(const char* file) const noexcept
91 {
92 std::vector< IdxUbyteData<N-1> > ret(0);
93
94 FILE* fp = fopen(file, "rb");
95 if (fp == nullptr)
96 {
97 fprintf(stderr, "%s\n", strerror(errno));
98 return ret;
99 }
100
101 uint32_t magic = this->m_read<32>(fp);
102 if (magic != MagicNumber)
103 {
104 fprintf(stderr, "magic number mismatch: 0x%08x != 0x%08x\n", magic, MagicNumber);
105 fclose(fp);
106 return ret;
107 }
108
109 uint32_t dims[N];
110 for (size_t i = 0; i < N; i++)
111 {
112 dims[i] = this->m_read<32>(fp);
113 printf("dim %zu: %u\n", i, dims[i]);
114 }
115
116 for (uint32_t i = 0; i < dims[0]; i++)
117 {
118 size_t bytes = 1;
119 IdxUbyteData<N-1>& data = ret.emplace_back();
120 for (size_t j = 1; j < N; j++)
121 {
122 data.dims[j-1] = dims[j];
123 bytes *= dims[j];
124 }
125
126 data.data = new uint8_t[bytes];
127 if (fread(data.data, 1, bytes, fp) < bytes)
128 {
129 fprintf(stderr, "%s\n", strerror(errno));
130 }
131 }
132
133 fclose(fp);
134 return ret;
135 }
136
137private:
138 constexpr static const uint32_t MagicNumber = 0x00000800 | N;
139
140 // 大端读
141 template<size_t bits>
142 uint32_t m_read(FILE* fp) const noexcept
143 {
144 uint32_t ret = 0;
145 uint8_t byte = 0;
146
147 for (size_t i = 0; i < bits / 8; i++)
148 {
149 ret <<= 8;
150 if (fread(&byte, 1, 1, fp) < 1)
151 {
152 fprintf(stderr, "%s\n", strerror(errno));
153 }
154 ret |= byte;
155 }
156
157 return ret;
158 }
159
160 // 大端写
161 template<size_t bits>
162 void m_write(FILE* fp, uintmax_t value) const noexcept
163 {
164 constexpr const size_t bytes = bits / 8;
165 uint8_t byte = 0;
166
167 for (size_t i = 1; i <= bytes; i++)
168 {
169 byte = static_cast<uint8_t>(value >> (8 * (bytes - i)));
170 fwrite(&byte, 1, 1, fp);
171 }
172 }
173};
174
175#endif // IDX_UBYTE_HPP